package com.ershijin.my;

import com.ershijin.util.JwtUtils;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * {@link AuthenticationProvider} implementing the OAuth 2.0 Resource Owner Password Credentials grant
 * ({@code grant_type=password}).
 *
 * <p>The incoming {@link OAuth2PasswordAuthenticationToken} must carry an authenticated
 * {@link OAuth2ClientAuthenticationToken} as its principal, and the client must be registered with the
 * {@link AuthorizationGrantType#PASSWORD} grant type. The resource owner's {@code username} and
 * {@code password} request parameters are verified through the supplied {@link AuthenticationManager}.
 * On success a JWT access token is issued. A refresh token is also issued when the client is registered
 * for {@link AuthorizationGrantType#REFRESH_TOKEN}. The resulting {@link OAuth2Authorization} is saved
 * through the {@link OAuth2AuthorizationService}.
 *
 * <p>Failures are reported as {@link OAuth2AuthenticationException} with these error codes:
 * <ul>
 *   <li>{@code invalid_client}: the client is not authenticated.</li>
 *   <li>{@code unauthorized_client}: the client may not use the password grant.</li>
 *   <li>{@code invalid_request}: the username or password parameter is missing.</li>
 *   <li>{@code invalid_grant}: the credentials are rejected.</li>
 *   <li>{@code invalid_scope}: a requested scope is not registered for the client.</li>
 *   <li>{@code server_error}: token generation or persistence fails unexpectedly.</li>
 * </ul>
 */
public final class OAuth2PasswordAuthenticationProvider implements AuthenticationProvider {

    private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
            new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);

    /** Verifies the resource owner's username/password. */
    private final AuthenticationManager authenticationManager;

    /** Supplies the token issuer; may be {@code null}, in which case no issuer is passed to the claims builder. */
    private final ProviderSettings providerSettings;

    private final OAuth2AuthorizationService authorizationService;

    private final JwtEncoder jwtEncoder;

    private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};

    private final Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;

    /**
     * Creates the provider.
     *
     * @param authorizationService  where issued authorizations are saved; must not be {@code null}
     * @param jwtEncoder            encodes the access token; must not be {@code null}
     * @param authenticationManager verifies resource-owner credentials; must not be {@code null}
     * @param providerSettings      source of the issuer claim; may be {@code null}
     */
    public OAuth2PasswordAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder,
            AuthenticationManager authenticationManager, ProviderSettings providerSettings) {
        Assert.notNull(authorizationService, "authorizationService cannot be null");
        Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
        Assert.notNull(authenticationManager, "authenticationManager cannot be null");
        this.jwtEncoder = jwtEncoder;
        this.authorizationService = authorizationService;
        this.authenticationManager = authenticationManager;
        this.providerSettings = providerSettings;
    }

    /**
     * Performs the password grant.
     *
     * @param authentication an {@link OAuth2PasswordAuthenticationToken}
     * @return an {@link OAuth2AccessTokenAuthenticationToken} holding the issued access (and optional refresh) token
     * @throws OAuth2AuthenticationException with an OAuth2 error code describing the failure
     */
    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        OAuth2PasswordAuthenticationToken passwordAuthentication = (OAuth2PasswordAuthenticationToken) authentication;

        OAuth2ClientAuthenticationToken clientPrincipal = getAuthenticatedClientElseThrowInvalidClient(passwordAuthentication);
        RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
        if (registeredClient == null
                || !registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.PASSWORD)) {
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
        }

        Map<String, Object> additionalParameters = passwordAuthentication.getAdditionalParameters();
        String username = getRequiredParameter(additionalParameters, OAuth2ParameterNames.USERNAME);
        String password = getRequiredParameter(additionalParameters, OAuth2ParameterNames.PASSWORD);

        Authentication usernamePasswordAuthentication = authenticateResourceOwner(username, password);

        // Default to every scope registered for the client; an explicit request must be a subset of those.
        Set<String> authorizedScopes = registeredClient.getScopes();
        if (!CollectionUtils.isEmpty(passwordAuthentication.getScopes())) {
            Set<String> unauthorizedScopes = passwordAuthentication.getScopes().stream()
                    .filter(requestedScope -> !registeredClient.getScopes().contains(requestedScope))
                    .collect(Collectors.toSet());
            if (!CollectionUtils.isEmpty(unauthorizedScopes)) {
                throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE);
            }
            authorizedScopes = new LinkedHashSet<>(passwordAuthentication.getScopes());
        }

        try {
            return issueTokens(registeredClient, clientPrincipal, passwordAuthentication,
                    usernamePasswordAuthentication, authorizedScopes);
        } catch (OAuth2AuthenticationException ex) {
            // Errors raised on purpose (e.g. by a jwtCustomizer) keep their own error code.
            throw ex;
        } catch (Exception ex) {
            throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR), ex);
        }
    }

    /**
     * Sets the customizer applied to the {@link JwtEncodingContext} before the access token is encoded.
     *
     * @param jwtCustomizer the customizer; must not be {@code null}
     */
    public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
        Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
        this.jwtCustomizer = jwtCustomizer;
    }

    @Override
    public boolean supports(Class<?> authentication) {
        return OAuth2PasswordAuthenticationToken.class.isAssignableFrom(authentication);
    }

    /**
     * Reads a mandatory request parameter.
     * Throws {@code invalid_request} when the parameter is absent, not a String, or empty.
     */
    private static String getRequiredParameter(Map<String, Object> parameters, String name) {
        Object value = parameters.get(name);
        if (!(value instanceof String) || ((String) value).isEmpty()) {
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
        }
        return (String) value;
    }

    /**
     * Verifies the resource owner's credentials.
     * A rejection (or a {@code null} result) is reported as {@code invalid_grant} rather than a server error.
     */
    private Authentication authenticateResourceOwner(String username, String password) {
        Authentication result;
        try {
            result = this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(username, password));
        } catch (AuthenticationException ex) {
            throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT), ex);
        }
        if (result == null) {
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
        }
        return result;
    }

    /**
     * Encodes the access token, optionally generates a refresh token, and saves the resulting authorization.
     */
    private OAuth2AccessTokenAuthenticationToken issueTokens(RegisteredClient registeredClient,
            OAuth2ClientAuthenticationToken clientPrincipal, OAuth2PasswordAuthenticationToken passwordAuthentication,
            Authentication resourceOwner, Set<String> authorizedScopes) {
        String issuer = this.providerSettings != null ? this.providerSettings.getIssuer() : null;
        // The resource owner (not the client) is the principal of a password grant. The same name is used
        // for the token subject and the stored authorization, so later lookups based on the authorization's
        // principal name resolve to the user.
        String principalName = resourceOwner.getName();

        JoseHeader.Builder headersBuilder = JwtUtils.headers();
        JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims(
                registeredClient, issuer, principalName, authorizedScopes);

        JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
                .registeredClient(registeredClient)
                .principal(resourceOwner)
                .authorizedScopes(authorizedScopes)
                .tokenType(OAuth2TokenType.ACCESS_TOKEN)
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                .authorizationGrant(passwordAuthentication)
                .build();
        this.jwtCustomizer.customize(context);

        JoseHeader headers = context.getHeaders().build();
        JwtClaimsSet claims = context.getClaims().build();
        Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims);
        OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
                jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(),
                jwtAccessToken.getExpiresAt(), authorizedScopes);

        OAuth2RefreshToken refreshToken = null;
        if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
            refreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
        }

        OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient)
                .principalName(principalName)
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                .token(accessToken, (metadata) ->
                        metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()))
                .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes)
                .attribute(Principal.class.getName(), resourceOwner);
        if (refreshToken != null) {
            authorizationBuilder.refreshToken(refreshToken);
        }
        this.authorizationService.save(authorizationBuilder.build());

        return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken, refreshToken);
    }

    /**
     * Creates a refresh token valid from now for {@code tokenTimeToLive}.
     */
    private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
        Instant issuedAt = Instant.now();
        Instant expiresAt = issuedAt.plus(tokenTimeToLive);
        return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
    }

    /**
     * Returns the authenticated client carried as the principal of {@code authentication}.
     * Adapted from
     * {@code org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils}.
     *
     * @param authentication the grant authentication whose principal should be the client
     * @return the authenticated {@link OAuth2ClientAuthenticationToken}
     * @throws OAuth2AuthenticationException with {@code invalid_client} if the principal is missing,
     *         of another type, or not authenticated
     */
    private static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) {
        Object principal = authentication.getPrincipal();
        // instanceof is null-safe, unlike principal.getClass().
        if (principal instanceof OAuth2ClientAuthenticationToken) {
            OAuth2ClientAuthenticationToken clientPrincipal = (OAuth2ClientAuthenticationToken) principal;
            if (clientPrincipal.isAuthenticated()) {
                return clientPrincipal;
            }
        }
        throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
    }
}
